{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# An empty CUDA_VISIBLE_DEVICES hides every GPU from CUDA, so the rest of the\n",
    "# notebook runs on CPU. It is set here, before TensorFlow is imported.\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = ''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_gan/python/estimator/tpu_gan_estimator.py:42: The name tf.estimator.tpu.TPUEstimator is deprecated. Please use tf.compat.v1.estimator.tpu.TPUEstimator instead.\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_gan/python/estimator/tpu_gan_estimator.py:42: The name tf.estimator.tpu.TPUEstimator is deprecated. Please use tf.compat.v1.estimator.tpu.TPUEstimator instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from tensor2tensor.data_generators import problem\n",
    "from tensor2tensor.data_generators import text_problems\n",
    "from tensor2tensor.data_generators import translate\n",
    "from tensor2tensor.layers import common_attention\n",
    "from tensor2tensor.utils import registry\n",
    "from tensor2tensor import problems\n",
    "import tensorflow as tf\n",
    "import os\n",
    "import logging\n",
    "import sentencepiece as spm\n",
    "import transformer_tag\n",
    "from tensor2tensor.layers import modalities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the SentencePiece model file; `sp` is the processor that the Encoder\n",
    "# class wraps for both model inputs and targets.\n",
    "vocab = 'sp10m.cased.t5.model'\n",
    "sp = spm.SentencePieceProcessor()\n",
    "sp.Load(vocab)\n",
    "\n",
    "class Encoder:\n",
    "    \"\"\"Adapter exposing a SentencePiece processor through `encode` / `decode`.\"\"\"\n",
    "\n",
    "    def __init__(self, sp):\n",
    "        self.sp = sp\n",
    "        # Advertised vocabulary is the processor's piece count plus 100.\n",
    "        # NOTE(review): the +100 presumably covers ids reserved beyond the\n",
    "        # SentencePiece pieces -- confirm against the training setup.\n",
    "        piece_count = sp.GetPieceSize()\n",
    "        self.vocab_size = piece_count + 100\n",
    "\n",
    "    def encode(self, s):\n",
    "        \"\"\"Return the SentencePiece ids for the string `s`.\"\"\"\n",
    "        piece_ids = self.sp.EncodeAsIds(s)\n",
    "        return piece_ids\n",
    "\n",
    "    def decode(self, ids, strip_extraneous = False):\n",
    "        \"\"\"Return the text for `ids`; `strip_extraneous` is accepted but unused.\"\"\"\n",
    "        # The processor is always handed a real list, whatever iterable came in.\n",
    "        id_list = list(ids)\n",
    "        return self.sp.DecodeIds(id_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Error-tag inventory: one dict per tag with its numeric 'class', a Malay\n",
    "# 'Description', and an incorrect ('salah') / corrected ('betul') example.\n",
    "# The list position of each entry equals its 'class' value.\n",
    "d = [\n",
    "    {'class': 0, 'Description': 'PAD', 'salah': '', 'betul': ''},\n",
    "    {\n",
    "        'class': 1,\n",
    "        'Description': 'kesambungan subwords',\n",
    "        'salah': '',\n",
    "        'betul': '',\n",
    "    },\n",
    "    {'class': 2, 'Description': 'tiada kesalahan', 'salah': '', 'betul': ''},\n",
    "    {\n",
    "        'class': 3,\n",
    "        'Description': 'kesalahan frasa nama, Perkara yang diterangkan mesti mendahului \"penerang\"',\n",
    "        'salah': 'Cili sos',\n",
    "        'betul': 'sos cili',\n",
    "    },\n",
    "    {\n",
    "        'class': 4,\n",
    "        'Description': 'kesalahan kata jamak',\n",
    "        'salah': 'mereka-mereka',\n",
    "        'betul': 'mereka',\n",
    "    },\n",
    "    {\n",
    "        'class': 5,\n",
    "        'Description': 'kesalahan kata penguat',\n",
    "        'salah': 'sangat tinggi sekali',\n",
    "        'betul': 'sangat tinggi',\n",
    "    },\n",
    "    {\n",
    "        'class': 6,\n",
    "        'Description': 'kata adjektif dan imbuhan \"ter\" tanpa penguat.',\n",
    "        'salah': 'Sani mendapat markah yang tertinggi sekali.',\n",
    "        'betul': 'Sani mendapat markah yang tertinggi.',\n",
    "    },\n",
    "    {\n",
    "        'class': 7,\n",
    "        'Description': 'kesalahan kata hubung',\n",
    "        'salah': 'Sally sedang membaca bila saya tiba di rumahnya.',\n",
    "        'betul': 'Sally sedang membaca apabila saya tiba di rumahnya.',\n",
    "    },\n",
    "    {\n",
    "        'class': 8,\n",
    "        'Description': 'kesalahan kata bilangan',\n",
    "        'salah': 'Beribu peniaga tidak membayar cukai pendapatan.',\n",
    "        'betul': 'Beribu-ribu peniaga tidak membayar cukai pendapatan',\n",
    "    },\n",
    "    {\n",
    "        'class': 9,\n",
    "        'Description': 'kesalahan kata sendi',\n",
    "        'salah': 'Umar telah berpindah daripada sekolah ini bulan lalu.',\n",
    "        'betul': 'Umar telah berpindah dari sekolah ini bulan lalu.',\n",
    "    },\n",
    "    {\n",
    "        'class': 10,\n",
    "        'Description': 'kesalahan penjodoh bilangan',\n",
    "        'salah': 'Setiap orang pelajar',\n",
    "        'betul': 'Setiap pelajar.',\n",
    "    },\n",
    "    {\n",
    "        'class': 11,\n",
    "        'Description': 'kesalahan kata ganti diri',\n",
    "        'salah': 'Pencuri itu telah ditangkap. Beliau dibawa ke balai polis.',\n",
    "        'betul': 'Pencuri itu telah ditangkap. Dia dibawa ke balai polis.',\n",
    "    },\n",
    "    {\n",
    "        'class': 12,\n",
    "        'Description': 'kesalahan ayat pasif',\n",
    "        'salah': 'Cerpen itu telah dikarang oleh saya.',\n",
    "        'betul': 'Cerpen itu telah saya karang.',\n",
    "    },\n",
    "    {\n",
    "        'class': 13,\n",
    "        'Description': 'kesalahan kata tanya',\n",
    "        'salah': 'Kamu berasal dari manakah ?',\n",
    "        'betul': 'Kamu berasal dari mana ?',\n",
    "    },\n",
    "    {\n",
    "        'class': 14,\n",
    "        'Description': 'kesalahan tanda baca',\n",
    "        'salah': 'Kamu berasal dari manakah .',\n",
    "        'betul': 'Kamu berasal dari mana ?',\n",
    "    },\n",
    "    {\n",
    "        'class': 15,\n",
    "        'Description': 'kesalahan kata kerja tak transitif',\n",
    "        'salah': 'Dia kata kepada saya',\n",
    "        'betul': 'Dia berkata kepada saya',\n",
    "    },\n",
    "    # NOTE(review): same 'Description' and examples as class 15, so a\n",
    "    # Description -> id dict can only keep one of the two ids. Confirm what\n",
    "    # class 16 means in the training data before renaming it.\n",
    "    {\n",
    "        'class': 16,\n",
    "        'Description': 'kesalahan kata kerja tak transitif',\n",
    "        'salah': 'Dia kata kepada saya',\n",
    "        'betul': 'Dia berkata kepada saya',\n",
    "    },\n",
    "    {\n",
    "        'class': 17,\n",
    "        'Description': 'kesalahan kata kerja transitif',\n",
    "        'salah': 'Dia suka baca buku',\n",
    "        'betul': 'Dia suka membaca buku',\n",
    "    },\n",
    "    {\n",
    "        'class': 18,\n",
    "        'Description': 'penggunaan kata yang tidak tepat',\n",
    "        'salah': 'Tembuk Besar negeri Cina dibina oleh Shih Huang Ti.',\n",
    "        'betul': 'Tembok Besar negeri Cina dibina oleh Shih Huang Ti',\n",
    "    },\n",
    "    {\n",
    "        'class': 19,\n",
    "        'Description': 'kesalahan frasa kerja tak transitif',\n",
    "        'salah': 'berdasarkan pada keterangan ini',\n",
    "        'betul': 'berdasarkan keterangan ini',\n",
    "    },\n",
    "    {\n",
    "        'class': 20,\n",
    "        'Description': 'kesalahan frasa kerja transitif',\n",
    "        'salah': 'Dia membeli banyak buah',\n",
    "        'betul': 'Dia banyak membeli buah',\n",
    "    },\n",
    "    # NOTE(review): 'saga' in the two examples below looks like a typo for\n",
    "    # 'saya' -- confirm before changing the strings.\n",
    "    {\n",
    "        'class': 21,\n",
    "        'Description': 'kesalahan frasa kerja pasif',\n",
    "        'salah': 'Surat itu saga akan balas',\n",
    "        'betul': 'Surat itu akan saga balas',\n",
    "    },\n",
    "]\n",
    "\n",
    "\n",
    "class Tatabahasa:\n",
    "    \"\"\"Encoder between error-tag descriptions and integer tag ids.\n",
    "\n",
    "    The id of a tag is its position in the list `d`; every entry of `d` must\n",
    "    have a 'Description' key. `vocab_size` is the number of entries.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, d):\n",
    "        self.d = d\n",
    "        # Description -> id. If two entries share a description, the later id wins.\n",
    "        self.kesalahan = {i['Description']: no for no, i in enumerate(self.d)}\n",
    "        # id -> Description, built straight from the list instead of inverting\n",
    "        # `kesalahan`: inverting drops every id whose description is repeated\n",
    "        # later in the list, and decoding such an id then raises KeyError.\n",
    "        self.reverse_kesalahan = {\n",
    "            no: i['Description'] for no, i in enumerate(self.d)\n",
    "        }\n",
    "        self.vocab_size = len(self.d)\n",
    "\n",
    "    def encode(self, s):\n",
    "        \"\"\"Map an iterable of descriptions to ids; unknown descriptions raise KeyError.\"\"\"\n",
    "        return [self.kesalahan[i] for i in s]\n",
    "\n",
    "    def decode(self, ids, strip_extraneous = False):\n",
    "        \"\"\"Map an iterable of ids to descriptions; `strip_extraneous` is unused.\n",
    "\n",
    "        Raises KeyError for an id outside `range(vocab_size)`.\n",
    "        \"\"\"\n",
    "        return [self.reverse_kesalahan[i] for i in ids]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "@registry.register_problem\n",
    "class Grammar(text_problems.Text2TextProblem):\n",
    "    \"\"\"grammatical error correction.\n",
    "\n",
    "    Text-to-text problem whose examples carry an extra `targets_error_tag`\n",
    "    feature next to `inputs` and `targets`.\n",
    "    \"\"\"\n",
    "\n",
    "    def feature_encoders(self, data_dir):\n",
    "        \"\"\"Return the encoder for each feature; `data_dir` is not used.\n",
    "\n",
    "        Inputs and targets share one SentencePiece `Encoder`; error tags use\n",
    "        `Tatabahasa`.\n",
    "        \"\"\"\n",
    "        encoder = Encoder(sp)\n",
    "        t = Tatabahasa(d)\n",
    "        return {'inputs': encoder, 'targets': encoder, 'targets_error_tag': t}\n",
    "\n",
    "    def hparams(self, defaults, model_hparams):\n",
    "        \"\"\"Add the tagging hyper-parameters and the error-tag modality.\n",
    "\n",
    "        Each tagging hparam is added to `model_hparams` only when it is missing,\n",
    "        so values that are already present are kept. When `use_error_tags` is\n",
    "        set, `targets_error_tag` is declared on `defaults` as a SYMBOL modality\n",
    "        whose vocabulary size comes from the tag encoder.\n",
    "        \"\"\"\n",
    "        super(Grammar, self).hparams(defaults, model_hparams)\n",
    "        tagging_defaults = (\n",
    "            ('use_error_tags', True),\n",
    "            ('middle_prediction', False),\n",
    "            ('middle_prediction_layer_factor', 2),\n",
    "            ('ffn_in_prediction_cascade', 1),\n",
    "            ('error_tag_embed_size', 12),\n",
    "        )\n",
    "        for name, value in tagging_defaults:\n",
    "            if name not in model_hparams:\n",
    "                model_hparams.add_hparam(name, value)\n",
    "        if model_hparams.use_error_tags:\n",
    "            defaults.modality[\n",
    "                'targets_error_tag'\n",
    "            ] = modalities.ModalityType.SYMBOL\n",
    "            error_tag_vocab_size = self._encoders[\n",
    "                'targets_error_tag'\n",
    "            ].vocab_size\n",
    "            defaults.vocab_size['targets_error_tag'] = error_tag_vocab_size\n",
    "\n",
    "    def example_reading_spec(self):\n",
    "        \"\"\"Return the parent's data fields plus a variable-length int64 tag field.\n",
    "\n",
    "        The second element of the returned pair is always None.\n",
    "        \"\"\"\n",
    "        # super() must name this class, in the same form as in hparams above.\n",
    "        data_fields, _ = super(Grammar, self).example_reading_spec()\n",
    "        data_fields['targets_error_tag'] = tf.VarLenFeature(tf.int64)\n",
    "        return data_fields, None\n",
    "\n",
    "    @property\n",
    "    def approx_vocab_size(self):\n",
    "        \"\"\"Fixed vocabulary size reported for this problem.\"\"\"\n",
    "        return 32100\n",
    "\n",
    "    @property\n",
    "    def is_generate_per_split(self):\n",
    "        \"\"\"Always False for this problem.\"\"\"\n",
    "        return False\n",
    "\n",
    "    @property\n",
    "    def dataset_splits(self):\n",
    "        \"\"\"Shard layout: 200 training shards and 1 evaluation shard.\"\"\"\n",
    "        return [\n",
    "            {'split': problem.DatasetSplit.TRAIN, 'shards': 200},\n",
    "            {'split': problem.DatasetSplit.EVAL, 'shards': 1},\n",
    "        ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Working directories, relative to the notebook's current directory.\n",
    "DATA_DIR = os.path.expanduser('t2t-tatabahasa/data')\n",
    "TMP_DIR = os.path.expanduser('t2t-tatabahasa/tmp')\n",
    "TRAIN_DIR = os.path.expanduser('t2t-tatabahasa/train-small')\n",
    "\n",
    "# os.makedirs creates missing parent directories and does nothing when the\n",
    "# directory already exists. Shelling out to plain `mkdir` fails in both of\n",
    "# those cases and the failure was never checked.\n",
    "os.makedirs(TRAIN_DIR, exist_ok = True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Look the problem up in the tensor2tensor registry by name.\n",
    "# NOTE(review): 'grammar' is presumably the name the registry derives from\n",
    "# the Grammar class registered in this notebook -- confirm.\n",
    "PROBLEM = 'grammar'\n",
    "t2t_problem = problems.problem(PROBLEM)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Registry names of the model and of the base hyper-parameter set.\n",
    "# NOTE(review): 'transformer_tag' is presumably registered by the local\n",
    "# `transformer_tag` module imported at the top -- confirm.\n",
    "MODEL = 'transformer_tag'\n",
    "HPARAMS = 'transformer_base'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensor2tensor.utils.trainer_lib import create_run_config, create_experiment\n",
    "from tensor2tensor.utils.trainer_lib import create_hparams\n",
    "from tensor2tensor.utils import registry\n",
    "from tensor2tensor import models\n",
    "from tensor2tensor import problems\n",
    "from tensor2tensor.utils import trainer_lib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/util/deprecation.py:507: calling count_nonzero (from tensorflow.python.ops.math_ops) with axis is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "reduction_indices is deprecated, use axis instead\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/util/deprecation.py:507: calling count_nonzero (from tensorflow.python.ops.math_ops) with axis is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "reduction_indices is deprecated, use axis instead\n"
     ]
    }
   ],
   "source": [
    "# Rank-2 int32 placeholders with both dimensions dynamic.\n",
    "# NOTE(review): presumably [batch, length] id matrices -- confirm.\n",
    "X = tf.placeholder(tf.int32, [None, None], name = 'x_placeholder')\n",
    "Y = tf.placeholder(tf.int32, [None, None], name = 'y_placeholder')\n",
    "targets_error_tag = tf.placeholder(tf.int32, [None, None], 'error_placeholder')\n",
    "# Count of non-zero ids in each row of X, and the largest such count; the\n",
    "# latter is what gets passed to the decoder as its length argument.\n",
    "X_seq_len = tf.count_nonzero(X, 1, dtype=tf.int32)\n",
    "maxlen_decode = tf.reduce_max(X_seq_len)\n",
    "\n",
    "# Append two trailing size-1 axes to each placeholder.\n",
    "x = tf.expand_dims(tf.expand_dims(X, -1), -1)\n",
    "y = tf.expand_dims(tf.expand_dims(Y, -1), -1)\n",
    "targets_error_tag_ = tf.expand_dims(tf.expand_dims(targets_error_tag, -1), -1)\n",
    "\n",
    "# NOTE(review): this dict passes the rank-2 `targets_error_tag`, while the\n",
    "# expanded `targets_error_tag_` built above is never used -- confirm which\n",
    "# shape the model expects. The dict is rebuilt without targets before the\n",
    "# inference graph is created.\n",
    "features = {\n",
    "    \"inputs\": x,\n",
    "    \"targets\": y,\n",
    "    \"target_space_id\": tf.constant(1, dtype=tf.int32),\n",
    "    'targets_error_tag': targets_error_tag,\n",
    "}\n",
    "Modes = tf.estimator.ModeKeys\n",
    "# Base hyper-parameter set for the registered problem; overridden in the next cell.\n",
    "hparams = trainer_lib.create_hparams(HPARAMS, data_dir=DATA_DIR, problem_name=PROBLEM)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Architecture overrides applied on top of the base hyper-parameter set.\n",
    "# Encoder and decoder both use `num_hidden_layers` layers.\n",
    "hparams.filter_size = 2048\n",
    "hparams.hidden_size = 512\n",
    "hparams.num_heads = 8\n",
    "hparams.num_hidden_layers = 6\n",
    "hparams.num_decoder_layers = hparams.num_hidden_layers\n",
    "hparams.num_encoder_layers = hparams.num_hidden_layers \n",
    "hparams.vocab_divisor = 128\n",
    "hparams.label_smoothing = 0.0\n",
    "hparams.shared_embedding_and_softmax_weights = False\n",
    "hparams.dropout = 0.1\n",
    "hparams.max_length = 1024\n",
    "hparams.multiproblem_mixing_schedule = \"pretrain\"\n",
    "\n",
    "# Optimizer and learning-rate schedule settings.\n",
    "# NOTE(review): presumably kept only to mirror the training configuration;\n",
    "# this notebook builds the model in PREDICT mode -- confirm.\n",
    "hparams.optimizer = \"Adafactor\"\n",
    "hparams.learning_rate_warmup_steps = 10000\n",
    "hparams.learning_rate_schedule = \"rsqrt_decay\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Setting T2TModel mode to 'infer'\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Setting T2TModel mode to 'infer'\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Setting hparams.dropout to 0.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Setting hparams.dropout to 0.0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Setting hparams.label_smoothing to 0.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Setting hparams.label_smoothing to 0.0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Setting hparams.layer_prepostprocess_dropout to 0.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Setting hparams.layer_prepostprocess_dropout to 0.0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Setting hparams.symbol_dropout to 0.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Setting hparams.symbol_dropout to 0.0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Setting hparams.attention_dropout to 0.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Setting hparams.attention_dropout to 0.0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Setting hparams.relu_dropout to 0.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Setting hparams.relu_dropout to 0.0\n"
     ]
    }
   ],
   "source": [
    "# Instantiate the registered model class in PREDICT mode; the log output of\n",
    "# this cell shows the dropout hparams being reset to 0.0 for inference.\n",
    "model = registry.model(MODEL)(hparams, Modes.PREDICT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# logits = model(features)\n",
    "# logits\n",
    "\n",
    "# sess = tf.InteractiveSession()\n",
    "# sess.run(tf.global_variables_initializer())\n",
    "# l = sess.run(logits, feed_dict = {X: [[10,10, 10, 10,10,1],[10,10, 10, 10,10,1]],\n",
    "#                              Y: [[10,10, 10, 10,10,1],[10,10, 10, 10,10,1]],\n",
    "#                              targets_error_tag: [[10,10, 10, 10,10,1],\n",
    "#                                                 [10,10, 10, 10,10,1]]})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor-1.15.7-py3.6.egg/tensor2tensor/layers/common_attention.py:931: to_float (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.cast` instead.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor-1.15.7-py3.6.egg/tensor2tensor/layers/common_attention.py:931: to_float (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.cast` instead.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor-1.15.7-py3.6.egg/tensor2tensor/models/transformer.py:96: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor-1.15.7-py3.6.egg/tensor2tensor/models/transformer.py:96: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor-1.15.7-py3.6.egg/tensor2tensor/utils/expert_utils.py:621: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.where in 2.0, which has the same broadcast rule as np.where\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor-1.15.7-py3.6.egg/tensor2tensor/utils/expert_utils.py:621: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.where in 2.0, which has the same broadcast rule as np.where\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor-1.15.7-py3.6.egg/tensor2tensor/utils/expert_utils.py:621: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.cast` instead.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor-1.15.7-py3.6.egg/tensor2tensor/utils/expert_utils.py:621: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.cast` instead.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/b2b/transformer_tag.py:1164: to_int64 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.cast` instead.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/b2b/transformer_tag.py:1164: to_int64 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.cast` instead.\n"
     ]
    }
   ],
   "source": [
    "# Inference feeds only the inputs; this replaces the earlier `features` dict.\n",
    "features = {\n",
    "    \"inputs\": x,\n",
    "    \"target_space_id\": tf.constant(1, dtype=tf.int32),\n",
    "}\n",
    "\n",
    "# Build the greedy decoding graph in the current variable scope with reuse off.\n",
    "# `fast_result` is later indexed with the keys 'outputs' and 'outputs_tag'.\n",
    "with tf.variable_scope(tf.get_variable_scope(), reuse = False):\n",
    "    fast_result = model._greedy_infer(features, maxlen_decode)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'t2t-tatabahasa/train-small/model.ckpt-200'"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Path of the most recent checkpoint recorded in TRAIN_DIR.\n",
    "ckpt_path = tf.train.latest_checkpoint(TRAIN_DIR)\n",
    "if ckpt_path is None:\n",
    "    # latest_checkpoint returns None when it finds no checkpoint; stop here\n",
    "    # with a clear message instead of failing later inside saver.restore.\n",
    "    raise FileNotFoundError('no checkpoint found in ' + TRAIN_DIR)\n",
    "ckpt_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Open a session and run the global variable initializer; the checkpoint\n",
    "# restore in the next cell then overwrites the restored variables.\n",
    "sess = tf.InteractiveSession()\n",
    "sess.run(tf.global_variables_initializer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from t2t-tatabahasa/train-small/model.ckpt-200\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from t2t-tatabahasa/train-small/model.ckpt-200\n"
     ]
    }
   ],
   "source": [
    "# Restore only the variables in the TRAINABLE_VARIABLES collection from the\n",
    "# checkpoint found above.\n",
    "var_lists = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)\n",
    "saver = tf.train.Saver(var_list = var_lists)\n",
    "saver.restore(sess, ckpt_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "# NOTE(review): pickle.load can execute arbitrary code stored in the file;\n",
    "# only load this dataset when it comes from a trusted source.\n",
    "with open('../pure-text/dataset-tatabahasa.pkl', 'rb') as fopen:\n",
    "    data = pickle.load(fopen)\n",
    "\n",
    "# SentencePiece wrapper used below to encode dataset rows and decode output ids.\n",
    "encoder = Encoder(sp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_xy(row, encoder):\n",
    "    \"\"\"Turn one dataset row into three id lists (x, y, tag), each with an end marker.\n",
    "\n",
    "    `row[0]` and `row[1]` are walked in parallel by index for `len(row[0])`\n",
    "    steps. The first element of every `row[0]` item is encoded into `y`. The\n",
    "    first element of every `row[1]` item is encoded into `x`, and its second\n",
    "    element is repeated once per produced id to form `tag`. Finally id 1 is\n",
    "    appended to `x` and `y`, and tag 0 to `tag`.\n",
    "\n",
    "    NOTE(review): presumably row[1] holds the erroneous text with its tags\n",
    "    and row[0] the corrected text -- confirm against the dataset builder.\n",
    "    \"\"\"\n",
    "    eos_id, eos_tag = 1, 0\n",
    "    src_ids, tgt_ids, tags = [], [], []\n",
    "\n",
    "    for pos in range(len(row[0])):\n",
    "        tgt_ids += encoder.encode(row[0][pos][0])\n",
    "        pieces = encoder.encode(row[1][pos][0])\n",
    "        src_ids += pieces\n",
    "        # One copy of the item's tag per id it was split into.\n",
    "        tags += [row[1][pos][1]] * len(pieces)\n",
    "\n",
    "    return src_ids + [eos_id], tgt_ids + [eos_id], tags + [eos_tag]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): this rebinds `x` and `y`, which earlier named the expanded\n",
    "# graph tensors built from the X / Y placeholders. Re-running the\n",
    "# graph-building cells after this one would pick up these Python lists.\n",
    "x, y, tag = get_xy(data[0], encoder)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the greedy decoding graph on a batch holding the single encoded example.\n",
    "r = sess.run(fast_result, \n",
    "         feed_dict = {X: [x]})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 2,\n",
       "        2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 2, 2, 2, 3, 3, 2, 2, 0]])"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Tag ids returned under 'outputs_tag' for the decoded example.\n",
    "r['outputs_tag']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n",
       "       2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 2, 2, 2, 3, 3, 2, 2, 0])"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Tag ids produced by get_xy for the same example, shown as an array so they\n",
    "# can be compared with the model's tags by eye.\n",
    "np.array(tag)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[  104,  6892,  3208,    13,    13,    25,     7,   749,    36,\n",
       "           15,     6, 15277,   844,    13,   564, 15277,   844,    13,\n",
       "          564,    15,     4,  2083,   417,   727,  4073,    15,     5,\n",
       "           34,   394,   648,   714,  1337,    17,   798,    18,  3481,\n",
       "         4963,    15,     3,     1]])"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Token ids returned under 'outputs' for the decoded example.\n",
    "r['outputs']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Dirk Janaas-Jan \" Huntelaar Huntelaar ( lahir 12 Ogos 1983 ) merupakan pemain bola sepak Belanda yang bermain di posisi penyerang .'"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Decode the first row of output ids back to text; tolist() converts the\n",
    "# NumPy row to a plain Python list first.\n",
    "encoder.decode(r['outputs'][0].tolist())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Dirk Jan Klaas \" Klaas-Jan \" Huntelaar ( lahir 12 Ogos 1983 ) merupakan bola sepak Belanda pemain yang bermain di penyerang posisi .'"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Decode the encoded input ids back to text, for comparison with the output above.\n",
    "encoder.decode(x)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
